#!/usr/bin/env python3
import argparse, math, os
import numpy as np
import pandas as pd
from astropy.io import fits
from multiprocessing import Pool, cpu_count

# ---------- helpers (same as original) ----------

def edges_from_csv(s: str):
    """Parse a comma-separated list of bin edges into a list of floats.

    Whitespace around entries and empty entries (e.g. a trailing comma) are
    ignored.

    Raises:
        ValueError: if an entry is not a number, if fewer than two edges
            remain, or if the edges are not strictly increasing (NaN included).
            The edges are used as ``np.digitize`` bin boundaries, with the last
            edge taken as the outer limit.
    """
    edges = [float(tok) for tok in s.split(",") if tok.strip()]
    if len(edges) < 2:
        raise ValueError(f"need at least two bin edges, got {s!r}")
    # `not (b > a)` also rejects NaN, for which every comparison is False.
    if any(not (b > a) for a, b in zip(edges, edges[1:])):
        raise ValueError(f"bin edges must be strictly increasing, got {s!r}")
    return edges

def midpoints(edges):
    """Return the arithmetic centre of each consecutive pair of edges as a float array."""
    arr = np.asarray(edges, dtype=float)
    lower, upper = arr[:-1], arr[1:]
    return (upper + lower) * 0.5

def assign_bin(x, edges):
    """Return the index of the half-open bin ``[edges[i], edges[i+1])`` holding x.

    Returns None when x lies below the first edge, at or above the last
    edge, or is NaN.
    """
    (slot,) = np.digitize([float(x)], edges, right=False)
    idx = slot - 1
    if idx < 0 or idx >= len(edges) - 1:
        return None
    return idx

def bin_label(edges, i):
    """Return a human-readable "lo–hi" label for bin i (en dash separator)."""
    lo, hi = edges[i], edges[i + 1]
    return f"{lo}–{hi}"

def make_colmap(tbl):
    """Map lower-cased column names of ``tbl`` to their original spelling.

    Uses ``tbl.columns.names`` when available, otherwise the ``.name`` of
    each item in ``tbl.columns``. When two names differ only in case, the
    later one wins.
    """
    cols = tbl.columns
    if hasattr(cols, "names"):
        names = list(cols.names)
    else:
        names = [col.name for col in cols]
    return dict((name.lower(), name) for name in names)

def pick(colmap, *cands):
    """Return the real column name for the first candidate present in ``colmap``.

    Candidates are matched case-insensitively against the lower-cased keys
    produced by ``make_colmap``.

    Raises:
        KeyError: if none of the candidates is present.
    """
    lookups = (colmap.get(cand.lower()) for cand in cands)
    hit = next((name for name in lookups if name is not None), None)
    if hit is None:
        raise KeyError(f"Missing expected column. Tried {cands}, have {list(colmap.values())}")
    return hit

# ---------- globals for workers ----------

# Per-process state, filled in by _init_worker and read by _process_lens.
# Source catalogue columns (z and m stay None when absent from the catalogue).
_src_ra = None
_src_de = None
_e1 = None
_e2 = None
_w = None
_z = None
_m = None
# Radial (arcsec) bin edges and number of radial bins.
_b_edges = None
_nb = None
# Lens size / stellar-mass bin edges.
_rg_edges = None
_ms_edges = None
# Minimum source-minus-lens redshift separation.
_min_zsep = None
# Whether to apply the multiplicative shear-bias correction.
_use_m_corr = False

def _init_worker(src_ra, src_de, e1, e2, w, z, m,
                 b_edges, nb, rg_edges, ms_edges, min_zsep, use_m_corr):
    """Pool initializer: store the shared catalogue and configuration as module globals.

    Runs once per worker process (it is passed as ``initializer``), so each
    task sent to ``_process_lens`` only needs to carry its lens tuple.
    """
    global _src_ra, _src_de, _e1, _e2, _w, _z, _m
    global _b_edges, _nb, _rg_edges, _ms_edges, _min_zsep, _use_m_corr
    # Source catalogue columns.
    _src_ra, _src_de, _e1, _e2, _w, _z, _m = src_ra, src_de, e1, e2, w, z, m
    # Binning configuration.
    _b_edges, _nb, _rg_edges, _ms_edges = b_edges, nb, rg_edges, ms_edges
    # Selection / correction switches.
    _min_zsep, _use_m_corr = min_zsep, use_m_corr

def _process_lens(lens_tuple):
    """
    Accumulate the tangential ellipticity of background sources around one lens.

    Runs in a pool worker and reads the source catalogue and binning
    configuration from the module globals set by ``_init_worker``.

    lens_tuple: (ra_deg, dec_deg, z_lens, RG_kpc, Mstar_log10)

    Returns None when the lens falls outside the R_G or M* bins. Otherwise it
    returns ``(key, s_we, s_w, 1)``:

    - ``key`` is ``(rg_label, ms_label)``.
    - ``s_we`` is an array of length ``_nb`` holding, per projected-radius
      bin, the weighted sum of tangential ellipticity, sum(w * e_t).
    - ``s_w`` is the matching normalisation: sum(w), or sum(w * (1 + m))
      when the m-correction is enabled. In both cases ``s_we / s_w`` is the
      (bias-corrected) tangential shear.
    - The trailing 1 is this lens's contribution to the stack's lens count.
      It is returned even when no source lands in a radial bin.
    """
    ra_deg, dec_deg, z_lens, RG_kpc, Mstar_log10 = lens_tuple

    rgi = assign_bin(RG_kpc, _rg_edges)
    msi = assign_bin(Mstar_log10, _ms_edges)
    if rgi is None or msi is None:
        return None

    key = (bin_label(_rg_edges, rgi), bin_label(_ms_edges, msi))
    empty = (key, np.zeros(_nb), np.zeros(_nb), 1)

    ra0 = float(ra_deg)
    dec0 = float(dec_deg)
    max_deg = float(_b_edges[-1]) / 3600.0

    # Flat-sky offsets in degrees. The RA difference is wrapped into
    # [-180, 180), so that a lens near RA = 0/360 still sees sources on the
    # other side of the wrap.
    cosd = math.cos(math.radians(dec0))
    dra = ((_src_ra - ra0 + 180.0) % 360.0 - 180.0) * cosd
    dde = _src_de - dec0
    sel = (np.abs(dra) < max_deg) & (np.abs(dde) < max_deg)

    if _z is not None:
        # Keep only sources sufficiently behind the lens; a NaN photo-z
        # compares False and is dropped.
        sel &= _z >= (float(z_lens) + _min_zsep)

    if not np.any(sel):
        return empty

    dx = dra[sel] * 3600.0  # arcsec
    dy = dde[sel] * 3600.0
    r = np.hypot(dx, dy)
    phi = np.arctan2(dy, dx)
    e_t = -(_e1[sel] * np.cos(2.0 * phi) + _e2[sel] * np.sin(2.0 * phi))

    ww = _w[sel]
    if _use_m_corr and _m is not None:
        # Multiplicative bias enters only the normalisation:
        # gamma_t = sum(w e_t) / sum(w (1 + m)). Rescaling w in both the
        # numerator and the denominator would cancel for constant m.
        norm = ww * np.clip(1.0 + _m[sel], 0.5, 2.0)
    else:
        norm = ww

    idx = np.digitize(r, _b_edges) - 1
    # Drop sources outside the radial range. Also drop any source with a
    # non-finite ellipticity, weight or m, which would otherwise turn a whole
    # bin into NaN.
    good = (idx >= 0) & (idx < _nb) & np.isfinite(e_t) & np.isfinite(norm)
    if not np.any(good):
        return empty

    bins = idx[good]
    s_we = np.bincount(bins, weights=ww[good] * e_t[good], minlength=_nb)[:_nb]
    s_w = np.bincount(bins, weights=norm[good], minlength=_nb)[:_nb]

    return key, s_we, s_w, 1

# ---------- main ----------

def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    ap = argparse.ArgumentParser(
        description="KiDS -> prestacked_stacks.csv (tangential shear vs b, grouped by size/mass bins) [parallel]"
    )
    ap.add_argument("--kids", required=True)
    ap.add_argument("--lenses", required=True,
                    help="CSV with lens_id,ra_deg,dec_deg,z_lens,RG_kpc,Mstar_log10")
    ap.add_argument("--out", default="data/prestacked_stacks.csv")
    ap.add_argument("--out-meta", default="data/prestacked_meta.csv")
    ap.add_argument("--rg-bins", default="5,7.5,10,12.5,15")
    ap.add_argument("--mstar-bins", default="10.2,10.5,10.8,11.1")
    ap.add_argument("--b-bins-arcsec",
                    default="10,15,22,32,46,66,95,137,198,285,410,592,855,1236,1787,2583")
    ap.add_argument("--min-zsep", type=float, default=0.1)
    ap.add_argument("--use-m-corr", action="store_true")
    ap.add_argument("--max-lenses", type=int, default=None)
    ap.add_argument("--n-proc", type=int, default=None,
                    help="number of parallel processes (default: all cores)")
    return ap.parse_args()


def _load_sources(path):
    """Read the source columns needed for stacking from a KiDS FITS catalogue.

    The table is taken from HDU 1. Column names are matched
    case-insensitively against a few common aliases. Returns
    ``(ra, dec, e1, e2, w, z, m)`` as float arrays. ``z`` and ``m`` are None
    when no photo-z or multiplicative-bias column is found.
    """
    # astype(float) returns a new array, so the columns remain valid after
    # the file (and its memory map) is closed by the context manager.
    with fits.open(path, memmap=True) as hdul:
        tbl = hdul[1]
        data = tbl.data
        colmap = make_colmap(tbl)

        ra_c = pick(colmap, "RAJ2000", "ALPHA_J2000", "RA", "ALPHAWIN_J2000")
        de_c = pick(colmap, "DECJ2000", "DELTA_J2000", "DEC", "DELTAWIN_J2000")
        e1_c = pick(colmap, "e1", "ELLIP1")
        e2_c = pick(colmap, "e2", "ELLIP2")
        w_c = pick(colmap, "weight", "W")
        # Optional columns: the first alias present wins, None if none match.
        z_c = next((colmap[c.lower()] for c in ("z_B", "Z_B", "PHOTOZ", "ZPHOT")
                    if c.lower() in colmap), None)
        m_c = next((colmap[c.lower()] for c in ("m", "MCOR", "m_corr", "M")
                    if c.lower() in colmap), None)

        def column(name):
            return None if name is None else data[name].astype(float)

        return tuple(column(c) for c in (ra_c, de_c, e1_c, e2_c, w_c, z_c, m_c))


def _write_outputs(stacks, rg_edges, ms_edges, b_mids, out_path, meta_path):
    """Write the stacked profiles and per-stack metadata as CSV files.

    ``stacks`` maps ``(rg_label, ms_label)`` to dicts with keys ``sum_we_t``,
    ``sum_w`` and ``nL``. Stacks are written in bin order (R_G bin first,
    then M* bin) rather than in worker completion order, so repeated runs
    give identically ordered files. Radial bins without weight (non-finite
    gamma_t) are skipped. Missing parent directories are created.
    """
    rg_order = [bin_label(rg_edges, i) for i in range(len(rg_edges) - 1)]
    ms_order = [bin_label(ms_edges, i) for i in range(len(ms_edges) - 1)]

    def bin_rank(key):
        return rg_order.index(key[0]), ms_order.index(key[1])

    rows, meta = [], []
    for rg_lab, ms_lab in sorted(stacks, key=bin_rank):
        acc = stacks[(rg_lab, ms_lab)]
        s_we, s_w = acc["sum_we_t"], acc["sum_w"]
        stack_id = f"{rg_lab}_{ms_lab}"
        with np.errstate(divide="ignore", invalid="ignore"):
            gamma_t = np.where(s_w > 0, s_we / s_w, np.nan)
        for k, gt in enumerate(gamma_t):
            if not np.isfinite(gt):
                continue
            rows.append(
                {
                    "stack_id": stack_id,
                    "R_G_bin": rg_lab,
                    "Mstar_bin": ms_lab,
                    "b": float(b_mids[k]),
                    "gamma_t": float(gt),
                    "weight": float(s_w[k]),
                }
            )
        meta.append(
            {
                "stack_id": stack_id,
                "n_lenses": int(acc["nL"]),
                # Always NaN: the per-stack mean R_G is not computed here.
                "R_G_mean_kpc": np.nan,
            }
        )

    for path in (out_path, meta_path):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
    # Explicit column lists keep the CSV header even when nothing was stacked.
    pd.DataFrame(
        rows, columns=["stack_id", "R_G_bin", "Mstar_bin", "b", "gamma_t", "weight"]
    ).to_csv(out_path, index=False)
    pd.DataFrame(
        meta, columns=["stack_id", "n_lenses", "R_G_mean_kpc"]
    ).to_csv(meta_path, index=False)
    print(f"Wrote {out_path} ({len(rows)} rows); meta {meta_path} ({len(meta)} stacks).")


def main():
    """Command-line entry point: build stacked tangential-shear profiles.

    The steps are:

    1. Read the lens CSV and the KiDS source catalogue.
    2. Run ``_process_lens`` for every lens in a process pool.
    3. Sum the per-lens radial sums into (R_G bin, M* bin) stacks.
    4. Write ``--out``: one row per stack and radial bin with a finite
       gamma_t; ``b`` is the bin midpoint in arcsec.
    5. Write ``--out-meta``: one row per stack with its lens count.
    """
    args = _parse_args()

    rg_edges = edges_from_csv(args.rg_bins)
    ms_edges = edges_from_csv(args.mstar_bins)
    b_edges = np.array(edges_from_csv(args.b_bins_arcsec), dtype=float)
    b_mids = midpoints(b_edges)
    nb = len(b_mids)

    # Order matters: it must match the tuple unpacking in _process_lens.
    need = ["ra_deg", "dec_deg", "z_lens", "RG_kpc", "Mstar_log10"]
    L = pd.read_csv(args.lenses)
    for c in need:
        if c not in L.columns:
            raise SystemExit(f"Missing column '{c}' in {args.lenses}")
    L = L.dropna(subset=need)
    if args.max_lenses:
        L = L.head(args.max_lenses)
    n_lenses = len(L)
    print(f"[info] Using {n_lenses} lenses")

    src_ra, src_de, e1, e2, w, z, m = _load_sources(args.kids)
    if args.use_m_corr and m is None:
        print("[warn] --use-m-corr given but no m column found; gamma_t is not m-corrected")

    lens_tuples = list(zip(*(L[c].values for c in need)))

    # Never start more workers than there are lenses. Keep chunks small enough
    # that every worker receives several of them.
    n_proc = max(1, min(args.n_proc or cpu_count(), max(n_lenses, 1)))
    chunksize = max(1, min(100, n_lenses // (4 * n_proc)))
    print(f"[info] Spawning {n_proc} processes")

    stacks = {}  # (rg_label, ms_label) -> {"sum_we_t":..., "sum_w":..., "nL":int}
    with Pool(
        processes=n_proc,
        initializer=_init_worker,
        initargs=(src_ra, src_de, e1, e2, w, z, m,
                  b_edges, nb, rg_edges, ms_edges,
                  args.min_zsep, args.use_m_corr),
    ) as pool:
        results = pool.imap_unordered(_process_lens, lens_tuples, chunksize=chunksize)
        for processed, res in enumerate(results, start=1):
            if processed % 1000 == 0:
                print(f"[info] {processed}/{n_lenses} lenses processed")
            if res is None:
                continue
            key, s_we, s_w, nL = res
            acc = stacks.setdefault(
                key,
                {"sum_we_t": np.zeros(nb), "sum_w": np.zeros(nb), "nL": 0},
            )
            acc["sum_we_t"] += s_we
            acc["sum_w"] += s_w
            acc["nL"] += nL

    _write_outputs(stacks, rg_edges, ms_edges, b_mids, args.out, args.out_meta)

if __name__ == "__main__":
    # Run the CLI only when executed as a script. multiprocessing requires
    # this guard: under the "spawn"/"forkserver" start methods each worker
    # re-imports the main module and must not re-run main().
    main()
